from rope_transformer import TransformerDecoder

a = TransformerDecoder(
    num_layers=3,  # 解码器的层数
    input_dim=2,  # 输入的维度
    hide_dim=2,  # 隐藏层维度
    n_q_heads=2,
    n_kv_heads=2,
    max_len=105,
)
print(a)
